library(pacman)
p_load(
  here, fst, data.table, ggplot2, collapse, lubridate, yardstick
)
source('R/02-power-plants/predict-production.r')

# Step 1: Get load data for 2019 and 2020 -------------------------------------
  # Hourly excess load by NERC region. Rows for the AK, ASCC and HI regions are
  # dropped before the region codes are lower-cased for the rest of the script.
  all_load_dt = local({
    raw_load = read.fst(
      path = here("Data/electricity-generation/clean-load/nerc-load-dt-oge.fst"),
      as.data.table = TRUE
    )
    dropped_regions = c('AK','ASCC','HI')
    raw_load[
      !(nerc_adj %in% dropped_regions),
      .(nerc_adj = str_to_lower(nerc_adj), datetime_utc, excess_load)
    ]
  })
  # Interconnection of each region: 'tre' -> Texas, 'wecc'/'cal' -> West,
  # anything else -> East (factor ordered West, Texas, East)
  all_load_dt[, ic := factor(
    fcase(
      nerc_adj == 'tre', 'Texas',
      nerc_adj %in% c('wecc','cal'), 'West',
      default = 'East'
    ),
    levels = c('West','Texas','East')
  )]
  # Load table handed to predict_production_region() in Step 2
  load_dt = prep_load_table(all_load_dt, id_var = 'datetime_utc')

# Step 2: Generate predictions from the model ---------------------------------
  # Prepping data. The seed is fixed so that any random draws made below
  # (presumably the epsilons from generate_epsilons()) are reproducible.
  set.seed(218)
  model_tables = prep_model_table()
  plant_dt = prep_plant_table(model_tables$model_dt)
  epsilon_dt = generate_epsilons(plant_dt)
  nerc_fit_path = 'Data/electricity-generation/plant-model-fit/predict-production'
  # recursive = TRUE also creates missing parent folders; without it
  # dir.create() fails (quietly, since warnings are switched off) whenever the
  # parent folder does not exist yet.
  dir.create(here(nerc_fit_path), showWarnings = FALSE, recursive = TRUE)
  # Upper-case codes of the NERC regions that already have a file in
  # nerc_fit_path. unique() keeps a region from being listed (and later read)
  # twice when more than one file name matches it.
  list_nerc_already_run = function(){
    list.files(here(nerc_fit_path)) |>
      str_extract('cal|wecc|tre|serc|mro|npcc|rfc') |>
      str_to_upper() |>
      na.omit() |>
      unique()
  }
  nerc_already_run = list_nerc_already_run()
  # Running for each region without an output file yet. The call is made for
  # its side effects (presumably writing one file per region under `path`),
  # so the returned list is discarded rather than printed.
  nerc_to_run = plant_dt[!(nerc_adj %in% nerc_already_run)]$nerc_adj |> unique()
  invisible(map(
    nerc_to_run,
    predict_production_region,
    load_dt = load_dt, 
    plant_dt = plant_dt, 
    model_tables = model_tables, 
    epsilon_dt = epsilon_dt,
    path = nerc_fit_path,  
    add_actual_data = TRUE, 
    id_var = 'datetime_utc'
  ))
  # Updating the already run list
  nerc_already_run = list_nerc_already_run()

# Step 4: Summarize and visualize fit -----------------------------------------
  # Function to aggregate fitted production and emissions.
  #   dt_in:   data.table holding actual columns (net_generation_mwh,
  #            emissions_tons_co2e/so2/nox) and their fit_* counterparts.
  #   by_cols: character vector of grouping columns; the result is keyed by
  #            them.
  #   emissions_in_lb: TRUE (default) when the actual emissions_tons_* columns
  #            still hold pounds. They are renamed from *_mass_lb_* columns just
  #            before the first aggregation, so their sums are divided by 2000
  #            (lb per short ton). Pass FALSE when re-aggregating a table that
  #            this function already produced. Converting again would shrink
  #            actual emissions by another factor of 2000.
  # Fitted emissions (fit_emissions_tons_*) are summed as-is.
  # prod_hours / fit_prod_hours are the shares of rows in each group with
  # positive actual / fitted production.
  agg_fit = function(dt_in, by_cols, emissions_in_lb = TRUE){
    lb_per_ton = if (emissions_in_lb) 2000 else 1
    dt_in[,.(
      net_generation_mwh = sum(net_generation_mwh),
      fit_net_generation_mwh = sum(fit_net_generation_mwh),
      emissions_tons_co2e = sum(emissions_tons_co2e)/lb_per_ton,
      fit_emissions_tons_co2e = sum(fit_emissions_tons_co2e),
      emissions_tons_so2 = sum(emissions_tons_so2)/lb_per_ton,
      fit_emissions_tons_so2 = sum(fit_emissions_tons_so2),
      emissions_tons_nox = sum(emissions_tons_nox)/lb_per_ton,
      fit_emissions_tons_nox = sum(fit_emissions_tons_nox),
      prod_hours = mean(net_generation_mwh > 0),
      fit_prod_hours = mean(fit_net_generation_mwh > 0)
      ), 
      keyby = by_cols
    ]
  }
  # Reading in and aggregating to nerc and fuel category: one row per fuel
  # category x datetime x year for every region already run, tagged with the
  # (upper-case) region code in nerc_adj.
  nerc_fuel_gen_dt = 
    map(
      nerc_already_run,
      \(x){
        read.fst(
          path = here(
            nerc_fit_path, 
            paste0(str_to_lower(x),'.fst')
          ),
          as.data.table = TRUE
        ) |>
        merge(
          plant_dt, 
          by = 'plant_id_eia'
        ) |>
        # Actual emissions are still in pounds here; agg_fit() converts them
        setnames(
          old = c('co2e_mass_lb_for_electricity','so2_mass_lb_for_electricity','nox_mass_lb_for_electricity'),
          new = c('emissions_tons_co2e','emissions_tons_so2','emissions_tons_nox')
        ) |>
        agg_fit(by_cols = c('fuel_category','datetime_utc','year')) %>%
        .[,nerc_adj := x]
      }
    ) |> rbindlist()
  # Adding interconnection, season and display fuel-category variables
  nerc_fuel_gen_dt[,':='(
    ic = fcase(
      nerc_adj == 'TRE', 'Texas',
      nerc_adj %in% c('WECC','CAL'), 'West',
      default = 'East'
    ) |> factor(levels = c('West','Texas','East')),
    season = fcase(
      month(datetime_utc) %in% 3:5, 'Spring',
      month(datetime_utc) %in% 6:8, 'Summer',
      month(datetime_utc) %in% 9:11, 'Fall',
      default = 'Winter'
    ) |>
    factor(levels = c('Winter','Spring','Summer','Fall')),
    fuel_category_n = 
      fcase(
        fuel_category == 'natural_gas', 'Natural Gas',
        fuel_category == 'coal', 'Coal',
        fuel_category == 'hydro', 'Hydro',
        fuel_category == 'nuclear', 'Nuclear',
        #fuel_category == 'geothermal', 'Geothermal',
        default = 'Other'
      ) |>
      factor(levels = c('Natural Gas','Coal', 'Nuclear','Hydro','Other'))
  )]
  # Actual hourly net generation for region `x`, read from the raw
  # elec-gen-dt file and aggregated to nerc_adj x fuel_category x datetime_utc.
  # Wind and solar rows are dropped. Plants are also dropped when their total
  # generation is not positive, or when their generation variance is zero or
  # NA (e.g. a single observation; data.table treats an NA filter as FALSE).
  get_actual_generation_nerc = function(x){
    # Reading raw data
    elec_gen_dt = 
      read.fst(
        path = here(
          'Data/electricity-generation/elec-gen-dt',
          paste0('elec-gen-dt-',str_to_lower(x),'.fst')
        ),
        as.data.table = TRUE
      )[!(fuel_category %in% c('wind','solar'))]
    # Checking for positive production and variance 
    elec_gen_dt[,
      ':='(
        tot_gen = sum(net_generation_mwh), 
        var_gen = var(net_generation_mwh)
      ),
      by = .(plant_id_eia)
    ]
    # Aggregating to nerc region+fuel cat
    nerc_fuel_actual_dt = 
      elec_gen_dt[tot_gen > 0 & var_gen > 0,
        .(net_generation_mwh = sum(net_generation_mwh)),
        keyby = .(nerc_adj, fuel_category, datetime_utc)
      ]
    return(nerc_fuel_actual_dt)
  }
  nerc_fuel_actual_dt = 
    map(nerc_already_run, get_actual_generation_nerc) |> 
    rbindlist()
  # Replacing actual data with data from elec-gen-dt (kept alongside the fit
  # table's own actuals; unmatched rows get NA)
  nerc_fuel_gen_dt = 
    merge(
      nerc_fuel_gen_dt,
      nerc_fuel_actual_dt,
      by = c('nerc_adj','fuel_category','datetime_utc'),
      all.x = TRUE
    )
  # Note: net_generation_mwh uses the same plants as in 2019. 
  #   net_generation_mwh_all_2020_plants does not restrict to same plants
  setnames(
    nerc_fuel_gen_dt,
    old = c('net_generation_mwh.x','net_generation_mwh.y'),
    new = c('net_generation_mwh', 'net_generation_mwh_all_2020_plants')
  )
  # Other aggregations. nerc_fuel_gen_dt already holds emissions in tons, so
  # the lb -> ton conversion must not be applied a second time.
  nerc_gen_dt = agg_fit(
    nerc_fuel_gen_dt,
    by_cols = c('nerc_adj','datetime_utc','year','season'),
    emissions_in_lb = FALSE
  )
  ic_fuel_gen_dt = agg_fit(
    nerc_fuel_gen_dt,
    by_cols = c('ic','fuel_category_n','datetime_utc','year','season'),
    emissions_in_lb = FALSE
  )
  ic_gen_dt = agg_fit(
    nerc_fuel_gen_dt,
    by_cols = c('ic','datetime_utc','year','season'),
    emissions_in_lb = FALSE
  )


# Simple predicted vs actual with smoothing -----------------------------------
  # R-squared of predicted vs actual production per interconnection and per
  # year(datetime_utc); the result is only displayed, not stored.
  ic_gen_dt |>
    dplyr::group_by(ic, year(datetime_utc)) |>
    yardstick::rsq(truth = net_generation_mwh, estimate = fit_net_generation_mwh)
  # For the interconnect: production after 2020-01-01 07:00:00 in GWh, one
  # panel per interconnection, with a smoother and a dashed 45-degree line
  model_fit_interconnect_p = local({
    ic_2020_dt = ic_gen_dt[datetime_utc > ymd_hms('2020-01-01 07:00:00')]
    ggplot(
      ic_2020_dt,
      aes(x = net_generation_mwh/1e3, y = fit_net_generation_mwh/1e3)
    ) +
      geom_point(alpha = 0.01, color = 'gray30', shape = 19) +
      geom_smooth() +
      geom_abline(intercept = 0, slope = 1, linetype = "dashed") +
      facet_wrap(~ic, scales = "free") +
      labs(x = "Actual Production (GWh)", y = "Predicted Production (GWh)") +
      theme_minimal()
  })
  ggsave(
    filename = here('figures/electricity-generation/model-fit-interconnect-2020.jpeg'),
    plot = model_fit_interconnect_p,
    width = 8, height = 3, units = 'in',
    bg = 'white'
  )
  # Same comparison per NERC region, excluding TRE
  model_fit_region_p = local({
    region_2020_dt = nerc_gen_dt[
      nerc_adj != 'TRE' & datetime_utc > ymd_hms('2020-01-01 07:00:00')
    ]
    ggplot(
      region_2020_dt,
      aes(x = net_generation_mwh/1e3, y = fit_net_generation_mwh/1e3)
    ) +
      geom_point(alpha = 0.01, color = 'gray30', shape = 19) +
      geom_smooth() +
      geom_abline(intercept = 0, slope = 1, linetype = "dashed") +
      facet_wrap(~nerc_adj, scales = "free") +
      labs(x = "Actual Production (GWh)", y = "Predicted Production (GWh)") +
      theme_minimal()
  })
  ggsave(
    filename = here('figures/electricity-generation/model-fit-region-2020.jpeg'),
    plot = model_fit_region_p,
    width = 8, height = 5, units = 'in',
    bg = 'white'
  )

# Predicted vs actual by hour and season --------------------------------------
  # lubridate (ymd_hms(), hour(), hours()) is already attached by p_load() at
  # the top of the script, so it is not loaded again here.
  # Mean production after 2020-01-01 07:00:00 by interconnection, hour of day
  # and season. The hour is taken from datetime_utc shifted back a fixed 5 hours
  # and labelled EST.
  ic_hour_season_dt =
    ic_gen_dt[datetime_utc > ymd_hms('2020-01-01 07:00:00'),.(
      Actual = mean(net_generation_mwh),
      Predicted = mean(fit_net_generation_mwh)), 
      by = .(ic, hour = hour(datetime_utc - hours(5)), season, year)
    ] |>
    melt(id.vars = c('ic','hour','season','year'))
  # Plot of hour of day vs season for interconnection
  model_fit_ic_hour_season_p =
    ggplot(ic_hour_season_dt, aes(
      x = hour, 
      y = value/1e3, 
      linetype = variable, 
      color = variable
    )) +
    geom_line(linewidth = 1.1) + 
    facet_grid(
      cols = vars(season), 
      rows = vars(ic), 
      scales = "free"
    ) +
    theme_minimal() + 
    scale_color_brewer(
      labels = c('Actual','Predicted'),
      palette = "Dark2"
    ) + 
    scale_linetype_manual(
      labels = c('Actual','Predicted'),
      values = c('twodash','solid')
    )+
    labs(
      x = "Hour (EST)", 
      y = "Electricity Production (GWh)",
      linetype = '',color = ''
    ) + 
    theme(legend.position = 'bottom')
  #model_fit_ic_hour_season_p
  ggsave(
    plot = model_fit_ic_hour_season_p,
    filename = here("figures/electricity-generation/model-fit-ic-hour-season-2020.jpeg"),
    bg = 'white',
    width = 8, height = 6
  )
  # Same interconnection plot with every panel's y axis starting at zero
  model_fit_ic_hour_season_y_zero_p = model_fit_ic_hour_season_p + ylim(0, NA)
  #model_fit_ic_hour_season_y_zero_p
  ggsave(
    plot = model_fit_ic_hour_season_y_zero_p,
    filename = here("figures/electricity-generation/model-fit-ic-hour-season-2020-yaxis-zero.jpeg"),
    bg = 'white',
    width = 8, height = 6
  )
  # Doing the same for regions now (TRE excluded). The line width uses the
  # `linewidth` aesthetic like the plots above; `size` for lines is deprecated
  # since ggplot2 3.4.0.
  model_fit_region_hour_season_p =
    nerc_gen_dt[nerc_adj !='TRE' & datetime_utc > ymd_hms('2020-01-01 07:00:00'),.(
      Actual = mean(net_generation_mwh),
      Predicted = mean(fit_net_generation_mwh)), 
      by = .(nerc_adj, hour = hour(datetime_utc - hours(5)),season 
      )] |>
    melt(
      id.vars = c('nerc_adj','hour','season')
    )|>
    ggplot(aes(x = hour, y = value/1e3, color = variable, linetype = variable)) +
    geom_line(linewidth = 1.1) + 
    facet_grid(
      cols = vars(season), 
      rows = vars(nerc_adj), 
      scales = "free"
    ) +
    theme_minimal() + 
    scale_color_brewer(
      palette = "Dark2", 
      labels = c('Actual','Predicted')
    ) + 
    scale_linetype_manual(
      values = c('twodash','solid'), 
      labels = c('Actual','Predicted')
    )+
    labs(
      x = "Hour (EST)", 
      y = "Electricity Production (GWh)",
      linetype = '',color = ''
    ) + 
    theme(legend.position = 'bottom')
  #model_fit_region_hour_season_p
  ggsave(
    plot = model_fit_region_hour_season_p,
    filename = here(
      "figures/electricity-generation/model-fit-region-hour-season-2020.jpeg"
    ),
    bg = 'white',
    width = 8, height = 9
  )

# Fuel mix vs load for each interconnection -----------------------------------
  # Each fuel category's share of hourly production (actual vs predicted),
  # smoothed against interconnection-level excess load, for timestamps after
  # 2020-01-01 07:00:00.
  ic_fuel_gen_p = local({
    ic_load = all_load_dt[,
      .(excess_load = sum(excess_load)),
      keyby = .(ic, datetime_utc)
    ]
    fuel_load = merge(
      ic_fuel_gen_dt[datetime_utc > ymd_hms('2020-01-01 07:00:00')],
      ic_load,
      by = c('ic','datetime_utc')
    )
    # Shares are computed within each interconnection-hour
    fuel_load[,
      ':='(
        pct_gen_actual = net_generation_mwh/sum(net_generation_mwh),
        pct_gen_predicted = fit_net_generation_mwh/sum(fit_net_generation_mwh)
      ),
      by = .(ic, datetime_utc)
    ]
    fuel_long = melt(
      fuel_load,
      id.vars = c('ic','fuel_category_n','datetime_utc','excess_load'),
      measure.vars = patterns('pct_gen')
    )
    ggplot(fuel_long, aes(
      x = excess_load/1e3,
      y = value,
      linetype = variable,
      color = fuel_category_n
    )) +
      geom_smooth(se = FALSE) +
      scale_color_brewer(name = 'Fuel Category', palette = 'Dark2') +
      scale_y_continuous(labels = scales::label_percent()) +
      labs(x = 'Excess Load (GWh)', y = 'Percent of Electricity Production') +
      scale_linetype_manual(
        name = 'Data',
        labels = c('Actual','Predicted'),
        values = c('twodash','solid')
      ) +
      facet_wrap(~ic, scales = 'free_x') +
      guides(
        linetype = guide_legend(keywidth = 4, keyheight = 1, override.aes = list(color = 'black')),
        colour = guide_legend(keywidth = 4, keyheight = 1)
      ) +
      theme_minimal() +
      theme(
        legend.position = 'bottom',
        legend.box = 'vertical',
        strip.text.x = element_text(size = 12)
      )
  })
  ggsave(
    filename = here("figures/electricity-generation/model-fit-ic-fuel-mix-2020.jpeg"),
    plot = ic_fuel_gen_p,
    width = 11, height = 5,
    bg = 'white'
  )
  # Same fuel-mix plot for each NERC region (TRE excluded). The x axis is still
  # the excess load of the region's interconnection.
  region_fuel_gen_p = local({
    ic_load = all_load_dt[,
      .(excess_load = sum(excess_load)),
      keyby = .(ic, datetime_utc)
    ]
    fuel_load = merge(
      nerc_fuel_gen_dt[
        nerc_adj != 'TRE' & datetime_utc > ymd_hms('2020-01-01 07:00:00')
      ],
      ic_load,
      by = c('ic','datetime_utc')
    )
    # Shares are computed within each region-hour
    fuel_load[,
      ':='(
        pct_gen_actual = net_generation_mwh/sum(net_generation_mwh),
        pct_gen_predicted = fit_net_generation_mwh/sum(fit_net_generation_mwh)
      ),
      by = .(ic, nerc_adj, datetime_utc)
    ]
    fuel_long = melt(
      fuel_load,
      id.vars = c('ic','nerc_adj','fuel_category_n','datetime_utc','excess_load'),
      measure.vars = patterns('pct_gen')
    )
    ggplot(fuel_long, aes(
      x = excess_load/1e3,
      y = value,
      linetype = variable,
      color = fuel_category_n
    )) +
      geom_smooth(se = FALSE) +
      scale_color_brewer(name = 'Fuel Category', palette = 'Dark2') +
      scale_y_continuous(labels = scales::label_percent()) +
      labs(x = 'Excess Load (GWh)', y = 'Percent of Electricity Production') +
      scale_linetype_manual(
        name = 'Data',
        labels = c('Actual','Predicted'),
        values = c('twodash','solid')
      ) +
      facet_wrap(~nerc_adj, scales = 'free_x') +
      guides(
        linetype = guide_legend(keywidth = 4, keyheight = 1, override.aes = list(color = 'black')),
        colour = guide_legend(keywidth = 4, keyheight = 1)
      ) +
      theme_minimal() +
      theme(
        legend.position = 'bottom',
        legend.box = 'vertical'
      )
  })
  ggsave(
    filename = here("figures/electricity-generation/model-fit-region-fuel-mix-2020.jpeg"),
    plot = region_fuel_gen_p,
    width = 11, height = 8,
    bg = 'white'
  )


# Plant level fit plots -------------------------------------------------------
# Reads the fitted-production file for NERC region `x` (lower-cased to build
# the file name), attaches plant attributes from plant_dt and aggregates with
# agg_fit() to one row per plant_id_eia x nerc_adj x fuel_category x year.
# Relies on the globals nerc_fit_path, plant_dt and agg_fit.
summarize_plant_gen_dt_plant = function(x){
  fit_file = here(nerc_fit_path, paste0(str_to_lower(x), '.fst'))
  region_dt = read.fst(path = fit_file, as.data.table = TRUE)
  region_dt = merge(region_dt, plant_dt, by = 'plant_id_eia')
  # Renamed by reference; agg_fit() converts the pound values to tons
  setnames(
    region_dt,
    old = c('co2e_mass_lb_for_electricity','so2_mass_lb_for_electricity','nox_mass_lb_for_electricity'),
    new = c('emissions_tons_co2e','emissions_tons_so2','emissions_tons_nox')
  )
  agg_fit(
    region_dt,
    by_cols = c('plant_id_eia','nerc_adj','fuel_category','year')
  )
}

# Plant-year fit for every region already run. rbindlist() (as used for the
# other region tables earlier in the script) returns a properly over-allocated
# data.table, so the `:=` below adds `ic` by reference. purrr::map_dfr() is
# superseded and goes through dplyr::bind_rows(), whose data.table output can
# trigger data.table's "Invalid .internal.selfref" copy-and-warn path on `:=`.
plant_gen_dt = 
  map(nerc_already_run, summarize_plant_gen_dt_plant) |>
  rbindlist()
# Interconnection of each plant's (upper-case) NERC region
plant_gen_dt[,
  ic := fcase(
    nerc_adj == 'TRE', 'Texas',
    nerc_adj %in% c('WECC','CAL'), 'West',
    default = 'East'
  ) |> factor(levels = c('West','Texas','East'))
]
# Total production hours predicted vs actual, 2020: each plant's share of rows
# (presumably hours) with positive production. Point size and the smoother's
# weight both come from the plant's actual MWh.
plant_hours_p = local({
  hours_dt = plant_gen_dt[
    year == 2020,
    .(Actual = prod_hours, Predicted = fit_prod_hours, mwh = net_generation_mwh),
    by = .(plant_id_eia, ic)
  ]
  ggplot(hours_dt, aes(x = Actual, y = Predicted, size = mwh, weight = mwh)) +
    geom_point(alpha = 0.3, color = 'gray30', shape = 19) +
    geom_smooth() +
    geom_abline(intercept = 0, slope = 1, linetype = "dashed") +
    # (faceting by ic, i.e. facet_wrap(~ic, scales = "free"), is left disabled)
    scale_x_continuous(
      name = 'Actual Production Hours',
      labels = scales::label_percent()
    ) +
    scale_y_continuous(
      name = 'Predicted Production Hours',
      labels = scales::label_percent()
    ) +
    theme_minimal() +
    guides(size = 'none')
})
ggsave(
  filename = here("figures/electricity-generation/model-fit-plant-hour-2020.jpeg"),
  plot = plant_hours_p,
  width = 8, height = 6,
  bg = 'white'
)

# Total production by plant ---------
# Annual 2020 production per plant, predicted vs actual. Axis labels show the
# MWh values scaled by 1e-3 (GWh).
plant_prod_p = local({
  prod_dt = plant_gen_dt[
    year == 2020,
    .(Actual = net_generation_mwh, Predicted = fit_net_generation_mwh),
    by = .(plant_id_eia, ic)
  ]
  gwh_labels = scales::label_number(scale = 1e-3, big.mark = ',')
  ggplot(prod_dt, aes(x = Actual, y = Predicted)) +
    geom_point(alpha = 0.3) +
    geom_abline(intercept = 0, slope = 1, linetype = "dashed") +
    # (smoother and faceting by ic are left disabled)
    scale_x_continuous(name = 'Actual Production (GWh)', labels = gwh_labels) +
    scale_y_continuous(name = 'Predicted Production (GWh)', labels = gwh_labels) +
    theme_minimal()
})
ggsave(
  filename = here("figures/electricity-generation/model-fit-plant-prod-2020.jpeg"),
  plot = plant_prod_p,
  width = 8, height = 6,
  bg = 'white'
)
    
  # Total production by plant-month ---------
  # Reads the fitted-production file for NERC region `x` (lower-cased to build
  # the file name) and attaches plant attributes from plant_dt. It then adds a
  # `month` column from month(datetime_utc) and aggregates with agg_fit() to
  # plant_id_eia x nerc_adj x fuel_category x year x month.
  # Relies on the globals nerc_fit_path, plant_dt and agg_fit.
  summarize_plant_gen_dt_plant_month = function(x){
    fit_file = here(nerc_fit_path, paste0(str_to_lower(x), '.fst'))
    month_dt = merge(
      read.fst(path = fit_file, as.data.table = TRUE),
      plant_dt,
      by = 'plant_id_eia'
    )
    # Renamed by reference; agg_fit() converts the pound values to tons
    setnames(
      month_dt,
      old = c('co2e_mass_lb_for_electricity','so2_mass_lb_for_electricity','nox_mass_lb_for_electricity'),
      new = c('emissions_tons_co2e','emissions_tons_so2','emissions_tons_nox')
    )
    month_dt[, month := month(datetime_utc)]
    agg_fit(
      month_dt,
      by_cols = c('plant_id_eia','nerc_adj','fuel_category','year', 'month')
    )
  }

  # Plant-month fit for every region already run. rbindlist() (as used for the
  # other region tables earlier in the script) returns a properly
  # over-allocated data.table, so the `:=` below adds `ic` by reference.
  # purrr::map_dfr() is superseded and goes through dplyr::bind_rows(), whose
  # data.table output can trigger data.table's "Invalid .internal.selfref"
  # copy-and-warn path on `:=`.
  plant_month_gen_dt = 
    map(nerc_already_run, summarize_plant_gen_dt_plant_month) |>
    rbindlist()
  # Interconnection of each plant's (upper-case) NERC region
  plant_month_gen_dt[,
    ic := fcase(
      nerc_adj == 'TRE', 'Texas',
      nerc_adj %in% c('WECC','CAL'), 'West',
      default = 'East'
    ) |> factor(levels = c('West','Texas','East'))
  ]
  # Monthly 2020 production per plant, predicted vs actual (MWh), with a
  # dashed 45-degree reference line
  plant_month_prod_p = local({
    month_2020_dt = plant_month_gen_dt[
      year == 2020,
      .(Actual = net_generation_mwh, Predicted = fit_net_generation_mwh),
      by = .(plant_id_eia, month, ic)
    ]
    mwh_labels = scales::label_number(scale = 1, big.mark = ',')
    ggplot(month_2020_dt, aes(x = Actual, y = Predicted)) +
      geom_point(alpha = 0.3) +
      geom_abline(intercept = 0, slope = 1, linetype = "dashed") +
      # (smoother and faceting by ic are left disabled)
      scale_x_continuous(name = 'Actual Production (MWh)', labels = mwh_labels) +
      scale_y_continuous(name = 'Predicted Production (MWh)', labels = mwh_labels) +
      theme_minimal()
  })
  ggsave(
    filename = here("figures/electricity-generation/model-fit-plant-prod-month-2020.jpeg"),
    plot = plant_month_prod_p,
    width = 8, height = 6,
    bg = 'white'
  )
